CS236605: Deep Learning on Computational Accelerators

Homework Assignment 3

Faculty of Computer Science, Technion.

Submitted by:

# Name Id email
Student 1 Yair Glassman 305782112 glassman@campus.technion.ac.il
Student 2 Martina Vanelli 921180428 martina@campus.technion.ac.il

Introduction

In this assignment we'll learn to generate text with a deep multilayer RNN network based on GRU cells. Then we'll focus our attention on image generation and implement two different generative models: a variational autoencoder and a generative adversarial network.

General Guidelines

  • Please read the getting started page on the course website. It explains how to set up, run and submit the assignment.
  • This assignment requires running on GPU-enabled hardware. Please read the course servers usage guide. It explains how to use and run your code on the course servers to benefit from training with GPUs.
  • The text and code cells in these notebooks are intended to guide you through the assignment and help you verify your solutions. The notebooks do not need to be edited at all (unless you wish to play around). The only exception is to fill your name(s) in the above cell before submission. Please do not remove sections or change the order of any cells.
  • All your code (and even answers to questions) should be written in the files within the python package corresponding to the assignment number (hw1, hw2, etc). You can of course use any editor or IDE to work on these files.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bb}[1]{\boldsymbol{#1}} $$

Part 1: Sequence Models

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!

In [2]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Text generation with a char-level RNN

Obtaining the corpus

Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.

In [3]:
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')

def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))
    
    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename
    
corpus_path = download_corpus()
Corpus file /home/glassman/.pytorch-datasets/shakespeare.txt exists, skipping download.

Load the text into memory and print a snippet:

In [4]:
with open(corpus_path, 'r') as f:
    corpus = f.read()

print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL

by William Shakespeare

Dramatis Personae

  KING OF FRANCE
  THE DUKE OF FLORENCE
  BERTRAM, Count of Rousillon
  LAFEU, an old lord
  PAROLLES, a follower of Bertram
  TWO FRENCH LORDS, serving with Bertram

  STEWARD, Servant to the Countess of Rousillon
  LAVACHE, a clown and Servant to the Countess of Rousillon
  A PAGE, Servant to the Countess of Rousillon

  COUNTESS OF ROUSILLON, mother to Bertram
  HELENA, a gentlewoman protected by the Countess
  A WIDOW OF FLORENCE.
  DIANA, daughter to the Widow

  VIOLENTA, neighbour and friend to the Widow
  MARIANA, neighbour and friend to the Widow

  Lords, Officers, Soldiers, etc., French and Florentine  

SCENE:
Rousillon; Paris; Florence; Marseilles

ACT I. SCENE 1.
Rousillon. The COUNT'S palace

Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black

  COUNTESS. In delivering my son from me, I bury a second husband.
  BERTRAM. And I in going, madam, weep o'er my father's death anew;
    but I must attend his Majesty's command, to whom I am now in
    ward, evermore in subjection.
  LAFEU. You shall find of the King a husband, madam; you, sir, a
    father. He that so generally is at all times good must of
    

Data Preprocessing

The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.

TODO: Implement the char_maps() function in the hw3/charnn.py module.
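As a rough sketch of the idea (the helper name `char_maps_sketch` is ours, not the assignment's; sorting the unique characters just makes the mapping deterministic):

```python
def char_maps_sketch(text: str):
    """Map each unique char in text to a unique index, and back."""
    unique_chars = sorted(set(text))  # sorted -> deterministic mapping
    char_to_idx = {c: i for i, c in enumerate(unique_chars)}
    idx_to_char = {i: c for i, c in enumerate(unique_chars)}
    return char_to_idx, idx_to_char

c2i, i2c = char_maps_sketch('abcab')
```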

In [5]:
import hw3.charnn as charnn

char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)

test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'\n': 0, ' ': 1, '!': 2, '"': 3, '$': 4, '&': 5, "'": 6, '(': 7, ')': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21, ':': 22, ';': 23, '<': 24, '?': 25, 'A': 26, 'B': 27, 'C': 28, 'D': 29, 'E': 30, 'F': 31, 'G': 32, 'H': 33, 'I': 34, 'J': 35, 'K': 36, 'L': 37, 'M': 38, 'N': 39, 'O': 40, 'P': 41, 'Q': 42, 'R': 43, 'S': 44, 'T': 45, 'U': 46, 'V': 47, 'W': 48, 'X': 49, 'Y': 50, 'Z': 51, '[': 52, ']': 53, '_': 54, 'a': 55, 'b': 56, 'c': 57, 'd': 58, 'e': 59, 'f': 60, 'g': 61, 'h': 62, 'i': 63, 'j': 64, 'k': 65, 'l': 66, 'm': 67, 'n': 68, 'o': 69, 'p': 70, 'q': 71, 'r': 72, 's': 73, 't': 74, 'u': 75, 'v': 76, 'w': 77, 'x': 78, 'y': 79, 'z': 80, '}': 81, '\ufeff': 82}

It seems we have some strange characters in the corpus that are very rare and probably due to mistakes. To reduce the length of the one-hot tensors we'll later use to represent our chars, it's best to remove them.

TODO: Implement the remove_chars() function in the hw3/charnn.py module.
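One possible approach, sketched with a hypothetical helper name (`str.translate` removes all occurrences in a single pass):

```python
def remove_chars_sketch(text: str, chars_to_remove):
    """Remove all occurrences of the given chars; report how many were removed."""
    clean = text.translate(str.maketrans('', '', ''.join(chars_to_remove)))
    return clean, len(text) - len(clean)

clean, n_removed = remove_chars_sketch('a$b}c', ['$', '}'])
```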

In [6]:
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')

# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars

The next thing we need is an embedding of the characters. An embedding is a representation of each token from the sequence as a tensor. For a char-level RNN, our tokens will be chars and we can thus use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented as a tensor whose length is the total number of unique chars (V), which contains all zeros except at the index corresponding to that specific char.

TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.
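A minimal sketch of both directions (function names suffixed with `_sketch` to emphasize these are illustrations, not the required implementations; note the test cell below expects `torch.int8`):

```python
import torch

def chars_to_onehot_sketch(text, char_to_idx):
    """Encode a string as an (S, V) one-hot tensor of dtype int8."""
    idx = torch.tensor([char_to_idx[c] for c in text])
    onehot = torch.zeros(len(text), len(char_to_idx), dtype=torch.int8)
    onehot[torch.arange(len(text)), idx] = 1
    return onehot

def onehot_to_chars_sketch(embedding, idx_to_char):
    """Decode an (S, V) one-hot tensor back into a string."""
    return ''.join(idx_to_char[i.item()] for i in embedding.argmax(dim=-1))

c2i, i2c = {'a': 0, 'b': 1}, {0: 'a', 1: 'b'}
oh = chars_to_onehot_sketch('ab', c2i)
```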

In [7]:
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)

text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))

test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
   
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int8)

Dataset Creation

We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.

We will split our corpus into shorter sequences of length S chars (try to think why; see question below). Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence. For each sample, we'll also need a label. This is simply another sequence, shifted by one char, so that the label of each char is the next char in the corpus.

TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.
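The splitting-and-shifting logic can be sketched as follows (our own simplified helper, without the `device` argument and ignoring edge cases; the last char of the corpus is dropped since it has no label):

```python
import torch

def chars_to_labelled_samples_sketch(text, char_to_idx, seq_len):
    """Split text into (N, S, V) one-hot samples and (N, S) next-char labels."""
    V = len(char_to_idx)
    num_samples = (len(text) - 1) // seq_len  # -1: last char has no label
    idx = torch.tensor([char_to_idx[c] for c in text[:num_samples * seq_len + 1]])
    samples_idx = idx[:-1].reshape(num_samples, seq_len)
    labels = idx[1:].reshape(num_samples, seq_len)  # shifted by one char
    samples = torch.zeros(num_samples, seq_len, V, dtype=torch.int8)
    samples.scatter_(-1, samples_idx.unsqueeze(-1), 1)  # one-hot encode
    return samples, labels

c2i = {c: i for i, c in enumerate('abcd')}
s, l = chars_to_labelled_samples_sketch('abcdabcda', c2i, seq_len=4)
```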

In [8]:
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)

# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')

# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))

# Test content
for _ in range(1000):
    # random sample
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
    
print(f'sample 100 as text:\n{unembed(samples[100])}')
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
sample 100 as text:
nity, though valiant in the
    defence, yet is weak. Unfold to 

As usual, instead of feeding one sample at a time into our model's forward we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively this will allow us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector during the forward pass.

Let's use the standard PyTorch Dataset/DataLoader combo. Luckily for the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label) from the samples and labels tensors we created above.

In [9]:
import torch.utils.data

# Create DataLoader returning batches of samples.
batch_size = 32

ds_corpus = torch.utils.data.TensorDataset(samples, labels)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, shuffle=False)

Let's see what that gives us:

In [10]:
print(f'num batches: {len(dl_corpus)}')

x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch sample: {x0.shape}')
print(f'shape of a batch label: {y0.shape}')
num batches: 3100
shape of a batch sample: torch.Size([32, 64, 78])
shape of a batch label: torch.Size([32, 64])

Model Implementation

Finally, our data set is ready so we can focus on our model.

What we'll implement here is a multilayer gated recurrent unit (GRU) model, with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but is somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.

The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.

Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as

$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$

The input to each layer is, $$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix} = \begin{cases} \mat{X} & \mathrm{if} ~k = 1~ \\ \mathrm{dropout}_p \left( \begin{bmatrix} {\vec{h}_1}^{[k-1]} \\ \vdots \\ {\vec{h}_S}^{[k-1]} \end{bmatrix} \right) & \mathrm{if} ~1 < k \leq L+1~ \end{cases}. $$

The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$

and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$

Notes:

  • $t\in[1,S]$ is the timestep, i.e. the current position within the sequence of each sample.
  • $\vec{x}_t^{[k]}$ is the input of layer $k$ at timestep $t$.
  • The outputs of the last layer, $\vec{y}_t^{[L]}$, are the predicted next-char scores for every input char. These are similar to class scores in classification tasks.
  • The hidden states at the last timestep, $\vec{h}_S^{[k]}$, are the final hidden state returned from the model.
  • $\sigma(\cdot)$ is the sigmoid function, i.e. $\sigma(\vec{z}) = 1/(1+e^{-\vec{z}})$ which returns values in $(0,1)$.
  • $\tanh(\cdot)$ is the hyperbolic tangent, i.e. $\tanh(\vec{z}) = (e^{2\vec{z}}-1)/(e^{2\vec{z}}+1)$ which returns values in $(-1,1)$.
  • $\vec{h_t}^{[k]}$ is the hidden state of layer $k$ at time $t$. This can be thought of as the memory of that layer.
  • $\vec{g_t}^{[k]}$ is the candidate hidden state for time $t$.
  • $\vec{z_t}^{[k]}$ is known as the update gate. It combines the previous state with the input to determine how much the current state will be combined with the new candidate state. For example, if $\vec{z_t}^{[k]}=\vec{1}$ then the current input has no effect on the output.
  • $\vec{r_t}^{[k]}$ is known as the reset gate. It combines the previous state with the input to determine how much of the previous state will affect the current state candidate. For example if $\vec{r_t}^{[k]}=\vec{0}$ the previous state has no effect on the current candidate state.
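To make the equations concrete, here's a minimal sketch of a single GRU timestep for one layer, operating on a batch (the parameter names `Wxz`, `Whz`, etc. mirror the math above but are our own; this is not the required `MultilayerGRU` structure):

```python
import torch

def gru_cell_step(x, h_prev, Wxz, Whz, bz, Wxr, Whr, br, Wxg, Whg, bg):
    """One GRU timestep for a single layer; x is (B, in_dim), h_prev is (B, H)."""
    z = torch.sigmoid(x @ Wxz.T + h_prev @ Whz.T + bz)     # update gate
    r = torch.sigmoid(x @ Wxr.T + h_prev @ Whr.T + br)     # reset gate
    g = torch.tanh(x @ Wxg.T + (r * h_prev) @ Whg.T + bg)  # candidate state
    return z * h_prev + (1 - z) * g                        # new hidden state

B, V, H = 4, 78, 16
params = []
for _ in range(3):  # one (W_x, W_h, b) triple per gate: z, r, g
    params += [torch.randn(H, V), torch.randn(H, H), torch.zeros(H)]
h = gru_cell_step(torch.randn(B, V), torch.zeros(B, H), *params)
```

Note how, when $z \to 1$, the new state reduces to the previous one, matching the update-gate description above.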

Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or use any mixture of the two (since the gates are continuous, taking values in $(0,1)$).

Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.

TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.

Notes:

  • You'll need to handle input batches now. The math is identical to the above, but all the tensors will have an extra batch dimension as their first dimension.
  • Use the diagram above to help guide your implementation. It will help you visualize what shapes to return where, etc.
In [11]:
in_dim = vocab_len
h_dim = 256
n_layers = 2
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)

# Test forward pass
y, h = model(x0.to(dtype=torch.float))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')

test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2) 
MultilayerGRU(
  (param_l0_p0): Linear(in_features=78, out_features=1, bias=True)
  (param_l0_p1): Linear(in_features=256, out_features=1, bias=False)
  (param_l0_p2): Linear(in_features=78, out_features=1, bias=True)
  (param_l0_p3): Linear(in_features=256, out_features=1, bias=False)
  (param_l0_p4): Linear(in_features=78, out_features=256, bias=True)
  (param_l0_p5): Linear(in_features=256, out_features=256, bias=False)
  (param_l0_p6): Dropout(p=0)
  (param_l1_p0): Linear(in_features=256, out_features=1, bias=True)
  (param_l1_p1): Linear(in_features=256, out_features=1, bias=False)
  (param_l1_p2): Linear(in_features=256, out_features=1, bias=True)
  (param_l1_p3): Linear(in_features=256, out_features=1, bias=False)
  (param_l1_p4): Linear(in_features=256, out_features=256, bias=True)
  (param_l1_p5): Linear(in_features=256, out_features=256, bias=False)
  (param_l1_p6): Dropout(p=0)
  (weights_hy): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 2, 256])

Generating text by sampling

Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability over each of the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t; \vec{h}_t).$$

Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.

The important point, however, is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (close to uniform) distributions if the score values are very similar. When sampling, we would prefer to control the distribution and make it less uniform, to increase the chance of sampling the char(s) with the highest scores compared to the others.

To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$

A low $T$ will result in less uniform distributions and vice-versa.

TODO: Implement the hot_softmax() function in the hw3/charnn.py module.
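Since the temperature only rescales the scores, the function is essentially a one-liner (sketched here under our own name; the assignment's `hot_softmax` uses the same `temperature` keyword, as the plotting cell below shows):

```python
import torch

def hot_softmax_sketch(y, dim=-1, temperature=1.0):
    """Softmax with temperature: scores are divided by T before exponentiation."""
    return torch.softmax(y / temperature, dim=dim)

scores = torch.tensor([1.0, 2.0, 3.0])
p_low = hot_softmax_sketch(scores, temperature=0.1)   # peaky: favors the max
p_high = hot_softmax_sketch(scores, temperature=100)  # near-uniform
```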

In [12]:
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))

for t in reversed([0.3, 0.5, 1.0, 100]):
    ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()

uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))

TODO: Implement the generate_from_model() function in the hw3/charnn.py module.
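The sampling loop described above can be sketched as follows (a simplified illustration with our own names and a toy stand-in model producing uniform scores; the real function works with the trained `MultilayerGRU` and its exact signature):

```python
import torch

def generate_sketch(model, start_seq, n_chars, char_maps, T=0.5):
    """Sample chars one at a time, feeding each sampled char and the
    hidden state back into the model until n_chars are generated."""
    char_to_idx, idx_to_char = char_maps
    V = len(char_to_idx)
    out = start_seq
    # Embed the start sequence as a (1, S, V) one-hot batch.
    x = torch.zeros(1, len(start_seq), V)
    for i, c in enumerate(start_seq):
        x[0, i, char_to_idx[c]] = 1
    h = None
    with torch.no_grad():
        while len(out) < n_chars:
            y, h = model(x, h)  # crucial: propagate the hidden state
            p = torch.softmax(y[0, -1, :] / T, dim=0)
            idx = torch.multinomial(p, num_samples=1).item()
            out += idx_to_char[idx]
            x = torch.zeros(1, 1, V)  # next input: just the sampled char
            x[0, 0, idx] = 1
    return out

def toy_model(x, h=None):
    """Stand-in for the trained GRU: returns uniform scores."""
    B, S, V = x.shape
    return torch.zeros(B, S, V), h

c2i, i2c = {'a': 0, 'b': 1}, {0: 'a', 1: 'b'}
text = generate_sketch(toy_model, 'a', 5, (c2i, i2c), T=0.5)
```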

In [13]:
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobard&cIfp)1:'lD0fQ4'(lLpmM?&&2Ay85.NbGV?.qVU:4r
foobarZpDfe0Z2uBcsP
!
YP6sF?ByY R6MXv"?emNkh2d'GXi
foobarHd ff
W]7 K6SkjCPHJJ6jkg6SrdD
H)yk6d(2" krHr

Training

To train such a model, we'll calculate the loss at each timestep by comparing the predicted char to the actual char from our label. We can use cross-entropy since, per char, this is similar to a classification problem. We'll then sum the losses over the sequence and back-propagate the gradients through time. Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times, so gradients will be accumulated in the parameters of the blocks. Luckily autograd will handle this part for us.

As usual, the first step of training will be to try to overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.

For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is to get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.

Let's create a tiny dataset to memorize.

In [14]:
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size=1, shuffle=False)

# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":

TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

Now let's implement the first part of our training code.

TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module. Note: Think about how to correctly handle the hidden state of the model between batches and epochs (for this specific task, i.e. text generation).
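One subtlety the note above hints at: if the hidden state is carried across batches, it must be detached from the computation graph at batch boundaries, otherwise back-propagation would try to unroll through all previous batches. A self-contained toy illustration of the idea (this is not the trainer code; `W` is a stand-in for the model's parameters):

```python
import torch

# Toy "recurrent" parameter standing in for the model's weights.
W = torch.randn(4, 4, requires_grad=True)

hidden = torch.zeros(1, 4)
for batch in range(3):  # pretend each iteration processes one batch
    x = torch.randn(1, 4)
    hidden = torch.tanh(x @ W + hidden @ W)
    loss = hidden.pow(2).sum()
    loss.backward()           # backprop reaches only into this batch...
    W.grad = None
    hidden = hidden.detach()  # ...because we cut the graph here
```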

In [15]:
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer

torch.manual_seed(42)

lr = 0.01
num_epochs = 500

in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
    
    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx,idx_to_char), T=0.1)
        # Stop if we've successfully memorized the small dataset.
        print(generated_sequence)
        if generated_sequence == subset_text:
            break

# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.819, Accuracy = 18.75%
Tos                                                        t                           o                t               t                     t    o                                                              t                                             

Epoch #25: Avg. loss = 0.230, Accuracy = 92.58%
TRAM. What I would not kiss.
    Faith, yes:
    Faith, yes:
    I would not kiss.
    Faith, yes:
    I would not in haste to horse.
    Faith, yes:
    I would not kiss.
    Faith, yes:
    Faith, yes:
    I would not in haste to horse.
    Faith, yes:
 

Epoch #50: Avg. loss = 0.025, Accuracy = 99.22%
TRAM. What would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not tell you what I would not t

Epoch #75: Avg. loss = 0.053, Accuracy = 98.05%
TRAM. What would you what I would you what I would not tell you what I would not tell you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would you what I would 

Epoch #100: Avg. loss = 0.003, Accuracy = 100.00%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

OK, so training works - we can memorize a short sequence. Next on the agenda is to split our full dataset into training and test sets of batched sequences.

In [16]:
# Full dataset definition
vocab_len = len(char_to_idx)
seq_len = 64
batch_size = 256
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)

samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)

ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=False, drop_last=True)

ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=True)

print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test:  {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
Train: 348 batches, 5701632 chars
Test:   38 batches,  622592 chars

We'll now train a much larger model on our large dataset. You'll need a GPU for this part.

The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left.

Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.

In [17]:
# Full training definition
lr = 0.001
num_epochs = 50

in_dim = out_dim = vocab_len
hidden_dim = 512
n_layers = 3
dropout = 0.5
checkpoint_file = 'checkpoints/rnn'
max_batches = 300
early_stopping = 5

model = charnn.MultilayerGRU(in_dim, hidden_dim, out_dim, n_layers, dropout)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

TODO:

  • Implement the fit() method of the Trainer class. You can reuse the implementation from HW2, but make sure to implement early stopping and checkpoints.
  • Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
  • Run the following block to train.
In [18]:
from cs236605.plot import plot_fit

def post_epoch_fn(epoch, test_res, train_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
        )
        print(generated_sequence)

# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sampling
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))

        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=max_batches,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
ACT I.nXg,;ryvqq]
LE3ySiA!FV4iN?CbYB4d0Ul&91ePz PZW:PsC61Af0stFh.tebq:!ZmdpNvxXI7U1r!Fxwv;yo: 7TVBC

*** Loading checkpoint file checkpoints/rnn.pt
--- EPOCH 1/50 ---
train_batch (Avg. Loss 1.541, Accuracy 55.0): 100%|██████████| 348/348 [01:20<00:00,  4.28it/s]
test_batch (Avg. Loss 1.547, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 12.37it/s]
ACT I.

GLOUCESTER:
No, and my seise would not poor friends and thy word of late to speak the treaso
--- EPOCH 2/50 ---
train_batch (Avg. Loss 1.548, Accuracy 54.8): 100%|██████████| 348/348 [01:21<00:00,  4.31it/s]
test_batch (Avg. Loss 1.552, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.94it/s]
ACT I.

KING RICHARD III:
I have we shall not us he pust with the fear of a more father's will of th
--- EPOCH 3/50 ---
train_batch (Avg. Loss 1.551, Accuracy 54.7): 100%|██████████| 348/348 [01:28<00:00,  3.98it/s]
test_batch (Avg. Loss 1.550, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.37it/s]
ACT I.

DUKE OF SYRACUSE:
But when the sealless than the prayers and men in some men of the souls an
--- EPOCH 4/50 ---
train_batch (Avg. Loss 1.551, Accuracy 54.7): 100%|██████████| 348/348 [01:27<00:00,  4.04it/s]
test_batch (Avg. Loss 1.552, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.53it/s]
Epoch     3: reducing learning rate of group 0 to 5.0000e-04.
ACT I.

KING RICHARD II:
Go the father's fear shall not so you think the power of the part of the mo
--- EPOCH 5/50 ---
train_batch (Avg. Loss 1.541, Accuracy 54.9): 100%|██████████| 348/348 [01:26<00:00,  4.04it/s]
test_batch (Avg. Loss 1.542, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.64it/s]
ACT I.
Be not so the dead,
Be a state
He were the world that comes him so the lands to see the great
--- EPOCH 6/50 ---
train_batch (Avg. Loss 1.536, Accuracy 55.1): 100%|██████████| 348/348 [01:25<00:00,  4.08it/s]
test_batch (Avg. Loss 1.541, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.71it/s]
ACT I.
Bear your will of the saper for the fight more free than the truth to die so run stay in the 
--- EPOCH 7/50 ---
train_batch (Avg. Loss 1.535, Accuracy 55.1): 100%|██████████| 348/348 [01:24<00:00,  3.80it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 10.73it/s]
ACT I.

LORD LORD MARIA be like all the field for the fortune that down and a good brother shall be 
--- EPOCH 8/50 ---
train_batch (Avg. Loss 1.534, Accuracy 55.1): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.541, Accuracy 53.1): 100%|██████████| 38/38 [00:03<00:00, 11.33it/s]
ACT I.
What see the sands and the book of the great a state.

KING RICHARD III:
I do of this such a 
--- EPOCH 9/50 ---
train_batch (Avg. Loss 1.534, Accuracy 55.2): 100%|██████████| 348/348 [01:27<00:00,  4.03it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.62it/s]
ACT I.

DUKE OF YORK:
My tongue to the great country be with his deserts to content of the gods of h
--- EPOCH 10/50 ---
train_batch (Avg. Loss 1.533, Accuracy 55.1): 100%|██████████| 348/348 [01:28<00:00,  3.75it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.04it/s]
ACT I.

KING RICHARD III:
The more in the life, and the prisoner that I have call me to shall entert
--- EPOCH 11/50 ---
train_batch (Avg. Loss 1.532, Accuracy 55.2): 100%|██████████| 348/348 [01:30<00:00,  3.91it/s]
test_batch (Avg. Loss 1.541, Accuracy 53.2): 100%|██████████| 38/38 [00:03<00:00, 11.31it/s]
ACT I.
And there is the present hand before the foolish cause of the vengeance of the kingdom, both 
--- EPOCH 12/50 ---
train_batch (Avg. Loss 1.532, Accuracy 55.2): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.541, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.36it/s]
ACT I.

CORIOLANUS:
O more than the matter to the sought and fear,
The dear sake his father,
And be 
--- EPOCH 13/50 ---
train_batch (Avg. Loss 1.534, Accuracy 55.1): 100%|██████████| 348/348 [01:28<00:00,  3.97it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.47it/s]
ACT I.
Best that the beard to the court of the service with eyes of the thoughts and dangerous and b
--- EPOCH 14/50 ---
train_batch (Avg. Loss 1.533, Accuracy 55.2): 100%|██████████| 348/348 [01:27<00:00,  3.97it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.48it/s]
ACT I.
Why, the brother than he hath heard the company to the beither with my slaves, which the part
--- EPOCH 15/50 ---
train_batch (Avg. Loss 1.531, Accuracy 55.2): 100%|██████████| 348/348 [01:27<00:00,  4.00it/s]
test_batch (Avg. Loss 1.539, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.41it/s]
ACT I.
The lords,
When he must be deserved with his heart the proper than the state have done the gr
--- EPOCH 16/50 ---
train_batch (Avg. Loss 1.531, Accuracy 55.2): 100%|██████████| 348/348 [01:24<00:00,  4.13it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.77it/s]
ACT I.
The lady shall have the plot of the good lord, the sand fortune that I do not see the poor we
--- EPOCH 17/50 ---
train_batch (Avg. Loss 1.531, Accuracy 55.2): 100%|██████████| 348/348 [01:23<00:00,  4.20it/s]
test_batch (Avg. Loss 1.539, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 12.05it/s]
ACT I.
Where is the care in the sight of my father, if he was a man shall bear the wars,
To the this
--- EPOCH 18/50 ---
train_batch (Avg. Loss 1.532, Accuracy 55.2): 100%|██████████| 348/348 [01:23<00:00,  4.23it/s]
test_batch (Avg. Loss 1.540, Accuracy 53.3): 100%|██████████| 38/38 [00:03<00:00, 11.91it/s]
Epoch    17: reducing learning rate of group 0 to 2.5000e-04.
ACT I.

KING RICHARD III:
The love so long man strange that is our love, and the coutin and some tho
--- EPOCH 19/50 ---
train_batch (Avg. Loss 1.529, Accuracy 55.3): 100%|██████████| 348/348 [01:23<00:00,  4.21it/s]
test_batch (Avg. Loss 1.536, Accuracy 53.4): 100%|██████████| 38/38 [00:03<00:00, 12.06it/s]
ACT I. Thou do thee to the son,
That see them be so not she will as the man down the matter he of th
--- EPOCH 20/50 ---
train_batch (Avg. Loss 1.524, Accuracy 55.4): 100%|██████████| 348/348 [01:23<00:00,  4.17it/s]
test_batch (Avg. Loss 1.536, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 12.10it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 20
ACT I.
The part were to never makes thee and scorn to the sparells of the most rass and the man deli
--- EPOCH 21/50 ---
train_batch (Avg. Loss 1.522, Accuracy 55.4): 100%|██████████| 348/348 [01:23<00:00,  4.04it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 12.18it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 21
ACT I.
The bind thee are here that is the change and the bed country me for his soul so fair office,
--- EPOCH 22/50 ---
train_batch (Avg. Loss 1.522, Accuracy 55.4): 100%|██████████| 348/348 [01:32<00:00,  3.60it/s]
test_batch (Avg. Loss 1.535, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 11.71it/s]
ACT I.
The tents of my son of the words,
And that I am so to see the particular sent and warrant to 
--- EPOCH 23/50 ---
train_batch (Avg. Loss 1.521, Accuracy 55.4): 100%|██████████| 348/348 [01:23<00:00,  4.23it/s]
test_batch (Avg. Loss 1.535, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 12.21it/s]
ACT I.
And the noble lord, and the son well in the soul,
The so not my son of this brother say the g
--- EPOCH 24/50 ---
train_batch (Avg. Loss 1.521, Accuracy 55.5): 100%|██████████| 348/348 [01:23<00:00,  4.16it/s]
test_batch (Avg. Loss 1.535, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 11.97it/s]
ACT I.
The world, I will love all the order of the best and country bloody truth of the action of th
--- EPOCH 25/50 ---
train_batch (Avg. Loss 1.521, Accuracy 55.5): 100%|██████████| 348/348 [01:25<00:00,  3.94it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.4): 100%|██████████| 38/38 [00:03<00:00, 11.47it/s]
ACT I.
The great a death, therefore desire the mistress of my mind,
The lines of my cannot be some s
--- EPOCH 26/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.533, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 11.43it/s]
ACT I. There will 'lord that he hath receive the looks of the hour against your most rest,
And the s
--- EPOCH 27/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:27<00:00,  3.98it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 11.20it/s]
ACT I.
Believe the world should say shall I make him not so many prince of the gallows to the death 
--- EPOCH 28/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.4): 100%|██████████| 38/38 [00:03<00:00, 11.50it/s]
ACT I.
What are the trite!

KING RICHARD III:
Now so the son and a state;
My grace of grace of the b
--- EPOCH 29/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:28<00:00,  3.97it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 11.43it/s]
ACT I. While the poir than in the most selves to the world, and great rest and my soul, then wherein
--- EPOCH 30/50 ---
train_batch (Avg. Loss 1.521, Accuracy 55.5): 100%|██████████| 348/348 [01:28<00:00,  3.80it/s]
test_batch (Avg. Loss 1.534, Accuracy 53.5): 100%|██████████| 38/38 [00:03<00:00, 10.22it/s]
ACT I.
And since thou wilt be better than the man of black now is the trimp of the matter to the swe
--- EPOCH 31/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:31<00:00,  3.89it/s]
test_batch (Avg. Loss 1.535, Accuracy 53.4): 100%|██████████| 38/38 [00:03<00:00, 11.28it/s]
Epoch    30: reducing learning rate of group 0 to 1.2500e-04.
ACT I.
The dead, all the father's head the grace of his friends and thy ground:
My lord,
And should 
--- EPOCH 32/50 ---
train_batch (Avg. Loss 1.520, Accuracy 55.5): 100%|██████████| 348/348 [01:29<00:00,  3.90it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.23it/s]
ACT I.

KING RICHARD III.
For the way before the sand beat accuse him not the charge the man of the 
--- EPOCH 33/50 ---
train_batch (Avg. Loss 1.517, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.90it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.30it/s]
ACT I.
The heart as sight, and this is the court of the part of death of his noble soul of the world
--- EPOCH 34/50 ---
train_batch (Avg. Loss 1.517, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.90it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.25it/s]
ACT I. Speak to heir my sake and place my lord to the great action of this strange talk to the speak
--- EPOCH 35/50 ---
train_batch (Avg. Loss 1.517, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.92it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.28it/s]
ACT I.
And therefore then we single master shall be so for my hand, see the sand and the more the Ki
--- EPOCH 36/50 ---
train_batch (Avg. Loss 1.516, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.88it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.32it/s]
ACT I.
The son, and enter the sand and more than the senses of what shall be bears to him the tent;

--- EPOCH 37/50 ---
train_batch (Avg. Loss 1.516, Accuracy 55.6): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.36it/s]
ACT I.
And my lord, and to her to answer of brother than see the crowns in the gird of the rest that
--- EPOCH 38/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:28<00:00,  3.98it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.44it/s]
ACT I.
And have not to my flattering the best against the bark me?

For your state and a word in the
--- EPOCH 39/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.89it/s]
test_batch (Avg. Loss 1.544, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.24it/s]
ACT I.
The country,
That we will be my condition is the prove my subject that my son, and so much th
--- EPOCH 40/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:28<00:00,  3.93it/s]
test_batch (Avg. Loss 1.544, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.47it/s]
ACT I.
The world when I will stands him to not done,
    And the proper sight shall think your own w
--- EPOCH 41/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:28<00:00,  3.82it/s]
test_batch (Avg. Loss 1.545, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.16it/s]
ACT I.
Be not the gracious beard my tongue, let me not serve me, and we be good more than that the w
--- EPOCH 42/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.87it/s]
test_batch (Avg. Loss 1.544, Accuracy 52.4): 100%|██████████| 38/38 [00:03<00:00, 11.11it/s]
ACT I.
The steel and hath leave thee to the country's love.
    Therefore then she can of the sea of
--- EPOCH 43/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.86it/s]
test_batch (Avg. Loss 1.544, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.14it/s]
ACT I. And of the partious father so straight had for the first matter that so? My lord.
    If he t
--- EPOCH 44/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:29<00:00,  3.92it/s]
test_batch (Avg. Loss 1.544, Accuracy 52.5): 100%|██████████| 38/38 [00:03<00:00, 11.31it/s]
Epoch    43: reducing learning rate of group 0 to 6.2500e-05.
ACT I.
Where is good strange reverend before the constant fair prosperity in fortune have you so som
--- EPOCH 45/50 ---
train_batch (Avg. Loss 1.515, Accuracy 55.6): 100%|██████████| 348/348 [01:28<00:00,  3.93it/s]
test_batch (Avg. Loss 1.560, Accuracy 52.2): 100%|██████████| 38/38 [00:03<00:00, 11.35it/s]
ACT I. She say me to a should be my stone
    To the beards the King would be come to my servant of 
--- EPOCH 46/50 ---
train_batch (Avg. Loss 1.513, Accuracy 55.7): 100%|██████████| 348/348 [01:28<00:00,  3.91it/s]
test_batch (Avg. Loss 1.561, Accuracy 52.2): 100%|██████████| 38/38 [00:03<00:00, 11.31it/s]
ACT I.
And there is more to be possess of the great service shall be her hand of a bloody part of th
--- EPOCH 47/50 ---
train_batch (Avg. Loss 1.512, Accuracy 55.7): 100%|██████████| 348/348 [01:28<00:00,  3.92it/s]
test_batch (Avg. Loss 1.561, Accuracy 52.2): 100%|██████████| 38/38 [00:03<00:00, 11.32it/s]
ACT I. She is to do for her country's mother's as desperate here,
    I know night.
  PANDARUS. Have
--- EPOCH 48/50 ---
train_batch (Avg. Loss 1.512, Accuracy 55.7): 100%|██████████| 348/348 [01:28<00:00,  3.90it/s]
test_batch (Avg. Loss 1.561, Accuracy 52.2): 100%|██████████| 38/38 [00:03<00:00, 11.33it/s]

Generating a work of art

Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.

TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.

In [25]:
import hw3.answers

start_seq, temperature = hw3.answers.part1_generation_params()

generated_sequence = charnn.generate_from_model(
    model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)

print(generated_sequence)
Once upon a time,
    And here, cousin Lord Sir himself to the heaven.
    Is the best is his and given him in a stoods to the body to hear my master's cold here, some to see the cast on the world will be provodus and father's mother,
    And hast thou nothing done,
    That flourish.

  PORTIA. I have the ring not the good wrong,
    To mine end
    May be my mother,
    And they are not as our stand'st thou will not be should great country than the chhate, my dobesing,
    I have some a
    Thing do so
    The pent of griefs, I will dead every compand and farewell,
    And I go, my lordsemour shall be mading, and then they be the hand of his hand.

Third Margaret, when it is one will comes the head of
    a led the wars
To seem for him in the life of the number I will not cannot be cruel the peace is mind,
    And whose moved to be ready esteem,
                            [They as we reason shall be a man it all turn;
Boy of his beelle to make you;
That you have love a great death for the and you to the death.
  FORD. An  and heaven fahe with him.
    What very noble better for the high been tower,
    And I may!
  FIRST GENTLEMAN. How betire him what she are not son should but the first for the very bed the place me requies of the dear
    to them given this day and flourish and sentence
    The Prince before that as court shall be them out, my lord, with his Lord Percy,
    And but ground;
    And they cannot proid of my griefs,
    When you heard them about the purse.
    The cause of down presently of Brutus I am but a red present.

CORIOLANUS:
Who cannot that man that so much the man of heaven and faints make their fair and guest to mine ear,
    Wherefore of my stall
    That strange than the day,
I should have news as in your honour,
    With the part, and this and say it be mack and wars in consent and violent part with pabered.
  Ade respect
    The world;
    And this dind mad,
    And see-the state;
    And then,
For I have all worthiest complexion of this like to make a mother?
  BASSANIO. That by a and the hind;
    And you say the gentle breawh the party against the sight!
  BOLINGBROKE:
This shall been children none;
    And that they she say you have no man his senses your hand of the prince of France with the sorrow hath seen to carry thee out of the charm I, that makes this
    with in by late,
    To all officers and dear man to the Hamber;
    And ever is so, I am merry, the deep to the words we may, but therefore is the hours come so land.
  PORTIA and BOTINGIRI. Marry, let me seem'd then as me in villain, nor the blood.
  BRUTUS. I that cannot we so be the proper
    To be stome with a sigh of him,
    And fair,
    As I cannot assure.
    Shall I that down to love other man.
    The now that be rough his matter would not read the our great achieves,
And they will be so rure and soul,
    I'll unfell my sake his good sovereign wanton of my end of the person.
    Thoy the every resalithed by by the more speak of her before;
    For you to do a friends, my lord.
  WILLIAMS:
I'll follow you and content to the chamcess
    Did have me something love again.
  BASSANIO. Thou wilt be straight is here and be dead to come no good man speak with the pence
    Than hards to-morrow for forsee in sir, by these former than the rate!
    There you will hear them a good constinticulous knows of the lord and care and their world,
    My praytrealous mind,
    With their that art thou not anengs of them,
    And good call not as love and cravour;
    And can think I will go made out of the brotherwhere and most long!
  SIR TOBY. I think you draw my graatus-more then with the Emperor.
  BATTARDIO. What, she will bold,
    That thou hast patient lady,
    Or now, nay, thou hast host
    That I
    The follower.
  BRUTUS. And that doth play the beat bars
    That direct the stated it courtesy in my complain,
And so much do your strong of worse in the foul and bad and my norship so seen that thou hast not perceive
    That shall in the plocks but there comes stay the early right and the wanton, sir, if they art thou she do be readed with his great mind to be run some daughter to his mother, to subpting the pleasure to we
    of the very grace a noble fair head;
    What to be so going behold friends the voice in thee,
    Or find the heads in their wit all the seas and order to your hand and word that the mounto your father's flesh that whise comes consul.
  LUCIO. I have a milkes and ways of love with me;
    The charity, lex me to be report worthy married.

CORIOLANUS:
Dull that this in the prizain to his country to the lady
    That I be him;
    Until here.

  GAUNT:
He hame you see thee to peaketh which they are a dead time for your blood, and now not all prove is no writ and brother of yours!

WITHA. Ay, Juloec makes many wrangle, if you should they would be more consent

  FRUDERIOS. Why I hear me here,
    And doth so thee, I have present to be set,
    And he comes shame to     
  CLOWN. I do with so!

QUEEN ELIZABETH:
Nor a put of company
    And like his side
    To do thee, the hands for a care of this force of the Prince of Angold sound I have done,
    And then, my lord.

QUEEN ELIZABETH:
And then thee to hear me you makes my friends from the contraly that I     And so my justs at the winled than beto a trouble!
  PANDARUS. My country,
    Who hath call to sir?
  FIRST HERMIONE. I certain as they pray;
    When I can not made them deorish the commanded,
    Come to kill me to what is not with us?

O soldiers,
    Or therefore, speak.
  PANDARUS. But fellow,
    That deserves it howen'd good noble death, and think you that the lets and ring in the tother bring the ropbourd, for the was not back.
    Now then.

BENVOLIO:
Charge my brues that grow to my list to the clain proclaim,
    The world,
    I would let us the King may streets.
  SECOND SENATOR. Now see him to the fortunes upon the walls of as offer to court recount the that which gave the dull and report the head
    I was all our love, the which I be not stay so comfort, Malbal as our fresh double son.
    You'll be returns up
    To be sees to thee,
    how less of France;
    I am my good country to the general now I think you do not have and he comes, and as the strength of she down the court,
    What says the corner to the siness in the wit with our mind,
    As you should bus so sa belly noble more,
    And the room,
    Ot much sweet general,
    And with his passion, my Amber that live my order to the mind,
    That with the patience
    The magainst your discourse of the boor,
    Of the son, and like a set of sound with a two of words that here is the reproochance and say.

BRUKUS

  AUFHES. She do do not than the man to will not are not frank you thou didst to them for I am gone.
    Thy best him love-a greet me be my fortunes, and that my speech him and men as all the commons,
    The name they comes it return and prove and meet your
    soy, go not to the King of I say to hear thy patience,
    In profession at words rounds
    And after no countrymon here in secret country her hate of the content to see you in me,
    And Messalo,
    I have not all the Duke of York and conspiran my royal prisoner to the wind of thy brother that reading to this yaunce of Face of a pardon me with such one my fortune to the love I condemn it my nettle's creature of my lordship with my soul as little,
    And the leave that I come that there be much of the Duke of Norfolk; and I have a prayer that made his into the body,
    And I can prisoners that had passion
    That I should be so noble Duke of Senator:
I pray you and still and diliness shall and Cassio.
  CASSIO. Ay, the that he that did fellows the recursed as the hand to all to see the more than you wall him a hand of the days to day.
  THIRD LORD. Hear mistress from my head, and that's the protection with hother content of all she?

KING RICHARD II:
Stone
    To the all thy death,
    And the words
    In love, the Troilus princely offence,
    I thankey,
    And with love,
    And that end withar my blood of his valour would therefore thou the blood,
    That is not so to the girds to die, the enemy?
  KING RICHARD. The and a blanken between a dangerous love is torch, and of your friends.
  BRUTUS. Come, where it shall farring to    Our truth will tell the lihe in death,
    How, though they would have no tongues of the commitly most whone of the gracious pardon! thou shalt go do so down tell the matter ever shall be curse the fortunes without to general desires to more to his better with the contents
    And never be griar of her coffin and painted him;
    And so, like a doupting can be for our world of my love, and by then the peace,
The certain to my counsel plood,
    Must bear it a disporate.
  GLOUCESTER. The commons, and for she may not be backsport.
  FIRST MURDERER. That you would be the besides him to them
    To be his black of your confine of grave.
  PRATEAN. Here is this witd for the cousin of Carser;
And say I am your ere to the such dispotity take you, we have heard of her father.
  ANTONIO. He loves him
    To poor
    That leave th' enemy.
  CLAUDIO. What may poor counsel to fair countrymen!
    There's no more than the son to my news, that my aws to straight of his master bear thy place.
  PROTEUS. That be by this raven to come that were the worse the sea, I am are the brother of
    not the lute his instructions, my lord, and here can from that there at a have been my prive to the sweet Cornind Comitia, and never as he at his house,
    And that shall not believe the Discharge for a fool for his face,
That thou dost
    stay at a banished and like and the that he'll see, my lord; see thee, in that sense of a language,
That there was were if my soul,
    The been man and take,
    And the hand.
  THIRD GENTLEMAN. What true court store,
And that that we will not not
    Thy kind, the man to son your head thy breashs, if you shall dead
    The world,
    Counterfai

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [26]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Why do we split the corpus into sequences instead of training on the whole text?

In [27]:
display_answer(hw3.answers.part1_q1)

Our goal is to predict the next character given the characters that precede it, so we train on character sequences, each labeled with the same sequence shifted one character forward. Splitting the corpus also keeps memory usage and the length of backpropagation-through-time bounded, and lets us train on batches of sequences in parallel; training on the entire text as one sequence would be infeasible.
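As an illustration, the splitting can be sketched like this (a hypothetical helper; the actual implementation in `hw3` may differ):

```python
def split_corpus(text, seq_len):
    """Split a corpus into (input, label) character sequences.
    Each label sequence is the input shifted one character forward,
    so the model is trained to predict the next character."""
    samples, labels = [], []
    for i in range(0, len(text) - seq_len, seq_len):
        samples.append(text[i:i + seq_len])
        labels.append(text[i + 1:i + seq_len + 1])
    return samples, labels

samples, labels = split_corpus("to be or not to be, that is the question", seq_len=8)
print(samples[0], '->', labels[0])  # "to be or" -> "o be or "
```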

Question 2

How is it possible that the generated text clearly shows memory longer than the sequence length?

In [28]:
display_answer(hw3.answers.part1_q2)

This is thanks to the hidden state of the GRU: the hidden state is carried over between batches, so information from the past is passed along for the prediction of the next character, giving the model an effective memory longer than the sequence length.

Question 3

Why are we not shuffling the order of batches when training?

In [29]:
display_answer(hw3.answers.part1_q3)

We want to predict the next character in its proper context, so we must also train in that context, i.e., present the batches in their original order, so that the hidden state carried between them remains meaningful. For example, consider a conversation between two characters in the play: the first character says "good morning" and of course the second answers "good morning". If we shuffled the batches, our model would miss such a trivial correlation.

Question 4

  1. Why do we lower the temperature for sampling (compared to the default of $1.0$ when training)?
  2. What happens when the temperature is very high and why?
  3. What happens when the temperature is very low and why?
In [30]:
display_answer(hw3.answers.part1_q4)

When the temperature is high, the generated characters are very random (and become more random as the temperature increases), because a high temperature pushes the next-character distribution towards the uniform distribution over all characters. Conversely, when the temperature is very low, we end up with the same few words repeated over and over, because as the temperature approaches $0$ the distribution assigns probability close to $1$ to the single most likely character. That is why during training we keep a neutral temperature of $1.0$, since we want to learn the full distribution over options, but when generating text we lower the temperature so that the samples stay as close as possible to the original corpus.
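A minimal NumPy sketch of temperature scaling (assuming, as is standard, that the logits are divided by $T$ before the softmax; `charnn`'s implementation may differ in details):

```python
import numpy as np

def softmax_with_temperature(logits, T):
    """Temperature-scaled softmax: divide the logits by T before
    normalizing. High T flattens the distribution towards uniform;
    low T concentrates it on the arg-max."""
    scaled = np.asarray(logits, dtype=float) / T
    scaled -= scaled.max()               # for numerical stability
    e = np.exp(scaled)
    return e / e.sum()

logits = [2.0, 1.0, 0.5]
p_hot = softmax_with_temperature(logits, T=100.0)   # nearly uniform
p_cold = softmax_with_temperature(logits, T=0.01)   # nearly one-hot
print(p_hot, p_cold)
```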

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 2: Variational Autoencoder

In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset, which contains many labeled faces of famous individuals.

We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)

However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.

In [4]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/glassman/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/glassman/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/glassman/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [5]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [6]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [7]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

The Variational Autoencoder

An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a neural net with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a neural net with parameters $\bb{\beta}$).

While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.

We define, in Bayesian terminology,

  • The prior distribution $p(\bb{Z})$ on points in the latent space.
  • The likelihood distribution of a sample $\bb{X}$ given a latent-space representation: $p(\bb{X}|\bb{Z})$.
  • The posterior distribution of points in the latent space given a specific instance: $p(\bb{Z}|\bb{X})$.
  • The evidence distribution $p(\bb{X})$ which is the distribution of the instance space due to the generative process.

To create our variational decoder we'll further specify:

  • A parametric likelihood distribution, $p _{\bb{\beta}}(\bb{X} | \bb{z}) = \mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$. The interpretation is that given a latent $\bb{z}$, we map it to a point normally distributed around the point calculated by our decoder neural network. Note that here $\sigma^2$ is a hyperparameter while $\vec{\beta}$ represents the network parameters.
  • A fixed latent-space prior distribution of $p(\bb{Z}) = \mathcal{N}(\bb{0},\bb{I})$.

This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.

Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder neural network, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ directly from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma} _{\bb{\alpha}}(\bb{x})$, i.e. scaling by the standard deviation (not the variance).
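The reparametrization trick can be sketched in a few lines of PyTorch (the function name `reparametrize` is ours, not part of the assignment API):

```python
import torch

def reparametrize(mu, log_sigma2):
    # Sample u ~ N(0, I), which does not depend on trainable parameters,
    # then shift and scale it by the encoder outputs. Note the standard
    # deviation is exp(0.5 * log-variance).
    u = torch.randn_like(mu)
    return mu + u * torch.exp(0.5 * log_sigma2)

# Sanity check: samples should have roughly the requested moments.
mu = torch.zeros(10000, 2) + torch.tensor([1.0, -1.0])
log_sigma2 = torch.log(torch.tensor([0.25, 4.0])).expand(10000, 2)
z = reparametrize(mu, log_sigma2)
print(z.mean(dim=0), z.var(dim=0))  # ≈ [1, -1] and ≈ [0.25, 4.0]
```

Gradients flow through `mu` and `log_sigma2` because the only random node, `u`, is parameter-free.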

To train a VAE model, we would like to maximize the evidence, $p(\bb{X})$, because $ p(\bb{X}) = \int p(\bb{X}|{\bb{z}})p(\bb{z})d\bb{z} $, so maximizing the evidence maximizes the likelihood of generated instances over the entire latent space.

The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{x})$. As we saw in the lecture, this expectation is intractable, but we can obtain a lower bound for $\log p(\bb{x})$ (the evidence lower bound, "ELBO"):

$$ \log p(\bb{X}) \ge \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}}} \left[ \log p _{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{X})\,\|\, p(\bb{Z})\right) $$

where $ \mathcal{D} _{\mathrm{KL}}(q\,\|\,p) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{z})}{p(\bb{z})} \right] $ is the Kullback–Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z|X})$ instead of the prior distribution $p(\bb{Z})$.

Using the ELBO, the VAE loss becomes, $$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}}}\left[ -\log p _{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\|\, p(\bb{Z})\right) \right]. $$

By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{u} \sim \mathcal{N}(\bb{0},\bb{I})} \left[ \frac{1}{2\sigma^2}\left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\|\, p(\bb{Z})\right) \right]. $$

Model Implementation

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).

First let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll treat this volume as the features extracted from the input image; later we'll use them to create the latent-space representation of the input.

TODO: Implement the EncoderCNN class in the hw3/autoencoder.py module. Implement any CNN architecture you like. If you need "architecture inspiration" you can see e.g. this or this paper.

In [8]:
import hw3.autoencoder as autoencoder

in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)

h = encoder_cnn(x0)
print(h.shape)

test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
  (cnn): Sequential(
    (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(3, 207, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(207, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(207, 207, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(207, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(207, 411, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(411, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(411, 411, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(411, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
    (15): Conv2d(411, 615, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (16): BatchNorm2d(615, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU()
    (18): Conv2d(615, 615, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(615, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU()
    (21): Conv2d(615, 819, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (22): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (23): ReLU()
    (24): Conv2d(819, 819, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU()
    (27): Conv2d(819, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)
torch.Size([1, 1024, 1, 1])

Now let's implement the CNN part of the Decoder. Again, this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced by your EncoderCNN and output an image with the same dimensions as the Encoder's input. This should be a CNN which is a "mirror image" of the Encoder. For example, replace convolutions with transposed convolutions, downsampling with up-sampling, etc. Consult the documentation of ConvTranspose2d to figure out how to reverse your convolutional layers in terms of input and output dimensions.
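As a quick sanity check of the "mirror image" idea (the layer hyperparameters here are illustrative, not the assignment's): a stride-2 convolution with kernel 4 and padding 1 halves the spatial size, and a transposed convolution with the same hyperparameters doubles it back:

```python
import torch
import torch.nn as nn

# Conv2d: out = floor((64 + 2*1 - 4) / 2) + 1 = 32 (spatial size halved)
conv = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)
# ConvTranspose2d: out = (32 - 1)*2 - 2*1 + 4 = 64 (spatial size restored)
deconv = nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1)

x = torch.randn(1, 3, 64, 64)
h = conv(x)
xr = deconv(h)
print(h.shape, xr.shape)  # (1, 8, 32, 32) and (1, 3, 64, 64)
```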

TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.

In [9]:
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)

test.assertEqual(x0.shape, x0r.shape)

# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
  (cnn): Sequential(
    (0): ReLU()
    (1): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): ReLU()
    (3): ConvTranspose2d(1024, 819, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(819, 819, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(819, 614, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(614, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(614, 614, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(614, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU()
    (15): ConvTranspose2d(614, 409, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (16): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU()
    (18): ConvTranspose2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (20): ReLU()
    (21): ConvTranspose2d(409, 204, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (22): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (23): ReLU()
    (24): ConvTranspose2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU()
    (27): ConvTranspose2d(204, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)
torch.Size([1, 3, 64, 64])
Out[9]:

Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:

  1. Produce a feature vector $\vec{h}$ from the input image $\vec{x}$.
  2. Use two affine transforms to convert the features into the mean and log-variance of the posterior, i.e. $$ \begin{align}
     \bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\
     \log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}}
    
    \end{align} $$
  3. Use the reparametrization trick to create the latent representation $\vec{z}$.

Note that we model the log of the variance, not the variance itself. The reason is that the log is easier to optimize, since (a) it doesn't have to be positive, and (b) it has a much larger dynamic range. The above formulation is proposed in appendix C of the VAE paper.
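Steps 2-3 above can be sketched as a small module (the names `EncoderHead`, `fc_mu`, `fc_logvar` are ours for illustration; the assignment's VAE class organizes this differently):

```python
import torch
import torch.nn as nn

class EncoderHead(nn.Module):
    """Illustrative sketch: map a flattened feature vector h to the posterior's
    mean and log-variance via two affine transforms, then reparametrize."""
    def __init__(self, h_dim, z_dim):
        super().__init__()
        self.fc_mu = nn.Linear(h_dim, z_dim)
        self.fc_logvar = nn.Linear(h_dim, z_dim)

    def forward(self, h):
        mu = self.fc_mu(h)
        log_sigma2 = self.fc_logvar(h)           # log-variance, not variance
        u = torch.randn_like(mu)                 # u ~ N(0, I): the only stochastic node
        z = mu + u * torch.exp(0.5 * log_sigma2) # reparametrization trick
        return z, mu, log_sigma2

head = EncoderHead(h_dim=1024, z_dim=2)
z, mu, log_sigma2 = head(torch.randn(1, 1024))
print(z.shape)  # torch.Size([1, 2])
```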

TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__().

In [10]:
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)

z, mu, log_sigma2 = vae.encode(x0)

test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)

print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')

# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
    for i in range(500):
        Z[i], _, _ = vae.encode(x0)
        ax.scatter(*Z[i].cpu().numpy())

# Should be close to the above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(3, 207, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(207, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(207, 207, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(207, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Conv2d(207, 411, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): BatchNorm2d(411, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(411, 411, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): BatchNorm2d(411, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): Conv2d(411, 615, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (16): BatchNorm2d(615, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (17): ReLU()
      (18): Conv2d(615, 615, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (19): BatchNorm2d(615, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (20): ReLU()
      (21): Conv2d(615, 819, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (22): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (23): ReLU()
      (24): Conv2d(819, 819, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (25): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (26): ReLU()
      (27): Conv2d(819, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ReLU()
      (1): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (2): ReLU()
      (3): ConvTranspose2d(1024, 819, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvTranspose2d(819, 819, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(819, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): ConvTranspose2d(819, 614, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (10): BatchNorm2d(614, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): ConvTranspose2d(614, 614, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): BatchNorm2d(614, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): ReLU()
      (15): ConvTranspose2d(614, 409, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (16): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (17): ReLU()
      (18): ConvTranspose2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (19): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (20): ReLU()
      (21): ConvTranspose2d(409, 204, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (22): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (23): ReLU()
      (24): ConvTranspose2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (25): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (26): ReLU()
      (27): ConvTranspose2d(204, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (at_mu): Linear(in_features=1024, out_features=2, bias=True)
  (at_logvar): Linear(in_features=1024, out_features=2, bias=True)
  (at_rec): Linear(in_features=2, out_features=1024, bias=True)
)
mu(x0)=[0.08198722, -0.008447006], sigma2(x0)=[0.992425, 1.0443755]
sampled mu tensor([ 0.0773, -0.0132])
sampled sigma2 tensor([0.9008, 0.9480])

Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows:

  1. Produce a feature vector $\tilde{\vec{h}}$ from the latent vector $\vec{z}$ using an affine transform.
  2. Reconstruct an image $\tilde{\vec{x}}$ from $\tilde{\vec{h}}$.
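A minimal sketch of this decoder head, assuming a feature volume of shape $1024\times 1\times 1$ like the one produced above (the names `fc_rec` and `h_shape` are ours):

```python
import torch
import torch.nn as nn

# Affine map from z back to the flattened feature dimension, reshaped into
# the volume the DecoderCNN expects as input.
z_dim, h_dim, h_shape = 2, 1024, (1024, 1, 1)
fc_rec = nn.Linear(z_dim, h_dim)

z = torch.randn(4, z_dim)
h_tilde = fc_rec(z).reshape(-1, *h_shape)  # feed this into the DecoderCNN
print(h_tilde.shape)  # torch.Size([4, 1024, 1, 1])
```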

TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.

In [11]:
x0r = vae.decode(z)

test.assertSequenceEqual(x0r.shape, x0.shape)

Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.

In [12]:
x0r, mu, log_sigma2 = vae(x0)

test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
Out[12]:

Loss Implementation

In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:

$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}) $$

where $d_z$ is the dimension of the latent space. This pointwise loss is the quantity that we'll compute and minimize with gradient descent.
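A hedged sketch of such a loss function follows. Note that the exact normalization (averaging over elements, the ½ factors) is a design choice, so this sketch will not reproduce the notebook's expected test value below:

```python
import torch

def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    # Illustrative only (our names and normalization, not the assignment's).
    N = x.shape[0]
    # Data term: squared reconstruction error, scaled by 1/sigma^2.
    data_loss = ((x - xr) ** 2).reshape(N, -1).sum(dim=1) / x_sigma2
    # KL term: closed form for N(mu, diag(sigma^2)) vs. the N(0, I) prior,
    # tr(Sigma) + ||mu||^2 - d_z - log det(Sigma), summed per sample.
    sigma2 = torch.exp(z_log_sigma2)
    kl_loss = (sigma2 + z_mu ** 2 - 1 - z_log_sigma2).sum(dim=1)
    return (data_loss + kl_loss).mean()

x = torch.zeros(2, 3, 4, 4); xr = torch.ones_like(x)
mu = torch.zeros(2, 8); logvar = torch.zeros(2, 8)
print(vae_loss_sketch(x, xr, mu, logvar, 1.0))  # data=48, KL=0 -> tensor(48.)
```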

TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.

In [13]:
from hw3.autoencoder import vae_loss
torch.manual_seed(42)

def test_vae_loss():
    # Test data
    N, C, H, W = 10, 3, 64, 64 
    z_dim = 32
    x  = torch.randn(N, C, H, W)*2 - 1
    xr = torch.randn(N, C, H, W)*2 - 1
    z_mu = torch.randn(N, z_dim)
    z_log_sigma2 = torch.randn(N, z_dim)
    x_sigma2 = 0.9
    
    loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
    
    test.assertAlmostEqual(loss.item(), 10.5053434, delta=1e-5)
    return loss

test_vae_loss()
Out[13]:
tensor(10.5053)

Sampling

The main advantage of a VAE is that it can be used as a generative model by sampling from the latent space, since the loss function pushes the posterior toward the Normal prior $p(\bb{Z})$. Let's now implement this so that we can visualize how our model is doing as it trains.
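A minimal sketch of such a sampling routine, written here as a standalone function taking a decoder callable (a real model would pass its own decode method; `sample_sketch` and `fake_decode` are our names):

```python
import torch

def sample_sketch(decode_fn, n, z_dim, device='cpu'):
    # Draw latent vectors from the N(0, I) prior and decode them into instances.
    # No gradients are needed for generation.
    with torch.no_grad():
        z = torch.randn(n, z_dim, device=device)
        return decode_fn(z)

# Usage with a stand-in decoder that just returns blank 64x64 RGB images:
fake_decode = lambda z: z.new_zeros(z.shape[0], 3, 64, 64)
samples = sample_sketch(fake_decode, n=5, z_dim=16)
print(samples.shape)  # torch.Size([5, 3, 64, 64])
```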

TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.

In [14]:
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)

Training

Time to train!

TODO:

  1. Implement the VAETrainer class in the hw3/training.py module.
  2. Tweak the hyperparameters in the part2_vae_hyperparam() function within the hw3/answers.py module.
In [29]:
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']

# Data
split_lengths = [int(len(ds_gwb)*0.9), int(len(ds_gwb)*0.1)]
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test  = DataLoader(ds_test,  batch_size, shuffle=True)
im_size = ds_train[0][0].shape

# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)

# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)

# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
    return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)

# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show model and hypers
print(vae)
print(hp)
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(3, 258, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(258, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(258, 258, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): Conv2d(258, 513, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): BatchNorm2d(513, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Conv2d(513, 513, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): ReLU()
      (13): Conv2d(513, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (14): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (15): ReLU()
      (16): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (17): ReLU()
      (18): Conv2d(768, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ReLU()
      (1): ConvTranspose2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (2): ReLU()
      (3): ConvTranspose2d(1024, 768, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvTranspose2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
      (8): ConvTranspose2d(768, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): ConvTranspose2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): ReLU()
      (13): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (15): ReLU()
      (16): ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (17): ReLU()
      (18): ConvTranspose2d(256, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (at_mu): Linear(in_features=4096, out_features=16, bias=True)
  (at_logvar): Linear(in_features=4096, out_features=16, bias=True)
  (at_rec): Linear(in_features=16, out_features=4096, bias=True)
)
{'batch_size': 32, 'h_dim': 1024, 'z_dim': 16, 'x_sigma2': 20, 'learn_rate': 0.0002, 'betas': (0.9, 0.999)}
In [31]:
import IPython.display

def post_epoch_fn(epoch, train_result, test_result, verbose):
    # Plot some samples if this is a verbose epoch
    if verbose:
        samples = vae.sample(n=5)
        fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final
else:
    res = trainer.fit(dl_train, dl_test,
                      num_epochs=200, early_stopping=20, print_every=10,
                      checkpoints=checkpoint_file,
                      post_epoch_fn=post_epoch_fn)
    
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
*** Loading checkpoint file checkpoints/vae.pt
--- EPOCH 1/200 ---
train_batch (Avg. Loss 0.013, Accuracy 247.2): 100%|██████████| 15/15 [00:04<00:00,  3.34it/s]
test_batch (Avg. Loss 0.014, Accuracy 276.5): 100%|██████████| 2/2 [00:00<00:00,  5.22it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 2
--- EPOCH 11/200 ---
train_batch (Avg. Loss 0.013, Accuracy 251.8): 100%|██████████| 15/15 [00:04<00:00,  3.31it/s]
test_batch (Avg. Loss 0.013, Accuracy 290.4): 100%|██████████| 2/2 [00:00<00:00,  5.49it/s]
--- EPOCH 21/200 ---
train_batch (Avg. Loss 0.013, Accuracy 250.5): 100%|██████████| 15/15 [00:04<00:00,  3.30it/s]
test_batch (Avg. Loss 0.014, Accuracy 265.3): 100%|██████████| 2/2 [00:00<00:00,  5.40it/s]
--- EPOCH 31/200 ---
train_batch (Avg. Loss 0.012, Accuracy 252.3): 100%|██████████| 15/15 [00:04<00:00,  3.22it/s]
test_batch (Avg. Loss 0.013, Accuracy 281.5): 100%|██████████| 2/2 [00:00<00:00,  5.05it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 32
*** Saved checkpoint checkpoints/vae.pt at epoch 38
--- EPOCH 41/200 ---
train_batch (Avg. Loss 0.013, Accuracy 252.3): 100%|██████████| 15/15 [00:04<00:00,  3.22it/s]
test_batch (Avg. Loss 0.013, Accuracy 289.1): 100%|██████████| 2/2 [00:00<00:00,  5.28it/s]
--- EPOCH 51/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.9): 100%|██████████| 15/15 [00:04<00:00,  3.21it/s]
test_batch (Avg. Loss 0.013, Accuracy 290.3): 100%|██████████| 2/2 [00:00<00:00,  5.13it/s]
--- EPOCH 61/200 ---
train_batch (Avg. Loss 0.012, Accuracy 252.9): 100%|██████████| 15/15 [00:04<00:00,  3.33it/s]
test_batch (Avg. Loss 0.013, Accuracy 293.9): 100%|██████████| 2/2 [00:00<00:00,  5.27it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 64
--- EPOCH 71/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.0): 100%|██████████| 15/15 [00:04<00:00,  3.28it/s]
test_batch (Avg. Loss 0.015, Accuracy 258.3): 100%|██████████| 2/2 [00:00<00:00,  5.43it/s]
--- EPOCH 81/200 ---
train_batch (Avg. Loss 0.012, Accuracy 252.8): 100%|██████████| 15/15 [00:04<00:00,  3.20it/s]
test_batch (Avg. Loss 0.013, Accuracy 299.5): 100%|██████████| 2/2 [00:00<00:00,  5.48it/s]
--- EPOCH 91/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.4): 100%|██████████| 15/15 [00:04<00:00,  3.32it/s]
test_batch (Avg. Loss 0.014, Accuracy 265.8): 100%|██████████| 2/2 [00:00<00:00,  5.43it/s]
--- EPOCH 101/200 ---
train_batch (Avg. Loss 0.013, Accuracy 253.3): 100%|██████████| 15/15 [00:04<00:00,  3.28it/s]
test_batch (Avg. Loss 0.014, Accuracy 274.9): 100%|██████████| 2/2 [00:00<00:00,  5.46it/s]
--- EPOCH 111/200 ---
train_batch (Avg. Loss 0.013, Accuracy 253.5): 100%|██████████| 15/15 [00:04<00:00,  3.13it/s]
test_batch (Avg. Loss 0.014, Accuracy 268.9): 100%|██████████| 2/2 [00:00<00:00,  4.97it/s]
--- EPOCH 121/200 ---
train_batch (Avg. Loss 0.012, Accuracy 252.7): 100%|██████████| 15/15 [00:04<00:00,  3.30it/s]
test_batch (Avg. Loss 0.013, Accuracy 281.8): 100%|██████████| 2/2 [00:00<00:00,  5.58it/s]
--- EPOCH 131/200 ---
train_batch (Avg. Loss 0.013, Accuracy 253.2): 100%|██████████| 15/15 [00:04<00:00,  3.29it/s]
test_batch (Avg. Loss 0.014, Accuracy 283.2): 100%|██████████| 2/2 [00:00<00:00,  5.45it/s]
--- EPOCH 141/200 ---
train_batch (Avg. Loss 0.013, Accuracy 252.8): 100%|██████████| 15/15 [00:04<00:00,  3.31it/s]
test_batch (Avg. Loss 0.014, Accuracy 271.8): 100%|██████████| 2/2 [00:00<00:00,  5.38it/s]
--- EPOCH 151/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.0): 100%|██████████| 15/15 [00:04<00:00,  3.18it/s]
test_batch (Avg. Loss 0.013, Accuracy 291.8): 100%|██████████| 2/2 [00:00<00:00,  4.90it/s]
--- EPOCH 161/200 ---
train_batch (Avg. Loss 0.012, Accuracy 254.1): 100%|██████████| 15/15 [00:04<00:00,  2.99it/s]
test_batch (Avg. Loss 0.014, Accuracy 276.1): 100%|██████████| 2/2 [00:00<00:00,  4.62it/s]
--- EPOCH 171/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.5): 100%|██████████| 15/15 [00:04<00:00,  3.09it/s]
test_batch (Avg. Loss 0.013, Accuracy 298.4): 100%|██████████| 2/2 [00:00<00:00,  5.35it/s]
--- EPOCH 181/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.2): 100%|██████████| 15/15 [00:04<00:00,  3.30it/s]
test_batch (Avg. Loss 0.014, Accuracy 264.5): 100%|██████████| 2/2 [00:00<00:00,  5.07it/s]
--- EPOCH 191/200 ---
train_batch (Avg. Loss 0.012, Accuracy 253.3): 100%|██████████| 15/15 [00:04<00:00,  3.27it/s]
test_batch (Avg. Loss 0.013, Accuracy 293.5): 100%|██████████| 2/2 [00:00<00:00,  5.47it/s]
--- EPOCH 200/200 ---
train_batch (Avg. Loss 0.012, Accuracy 252.6): 100%|██████████| 15/15 [00:04<00:00,  3.10it/s]
test_batch (Avg. Loss 0.014, Accuracy 266.1): 100%|██████████| 2/2 [00:00<00:00,  5.21it/s]
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [32]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.

In [24]:
display_answer(hw3.answers.part2_q1)

Your answer: The VAE loss is composed of two parts: the reconstruction (data) loss and the KL-divergence term. The reconstruction loss measures how close the output is to the input, while the KL divergence enforces the probabilistic assumptions by keeping the approximate posterior close to the prior. The hyperparameter $\sigma^2$ controls the balance between these two terms. With a low $\sigma^2$, the reconstruction term dominates: reconstructions are accurate, but the posterior matches the prior poorly, so sampling from the latent space works worse. With a high $\sigma^2$, the KL term dominates, and the reconstructed image may differ substantially from the input.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 3: Generative Adversarial Networks

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

We'll use the same data as in Part 2.

But again, to use a custom dataset, edit the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/glassman/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/glassman/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/glassman/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

Generative Adversarial Nets (GANs)

GANs, first proposed in a 2014 paper by Ian Goodfellow et al., are today arguably the most popular type of generative model, currently producing state-of-the-art results in generative tasks over many different domains.

In a GAN model, two different neural networks compete against each other: A generator and a discriminator.

  • The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{Z} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{z}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). This induces a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$, which we would typically like to be as close as possible to the real evidence distribution, $p(\bb{X})$.

  • The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

Training GANs

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
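As a quick numeric illustration, the inner objective can be evaluated directly for some made-up discriminator outputs (the probability values below are purely hypothetical):

```python
import torch

# Hypothetical discriminator probabilities on a few real and generated samples.
d_real = torch.tensor([0.9, 0.8, 0.95])   # Delta(x) for real samples x
d_fake = torch.tensor([0.1, 0.3, 0.05])   # Delta(Psi(z)) for generated samples

# The inner objective of the min-max game: the discriminator maximizes it
# (by pushing d_real -> 1 and d_fake -> 0), while the generator minimizes it
# (by pushing d_fake -> 1). Both log terms are at most 0.
obj = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()
```

Since probabilities lie in $[0,1]$, both expectation terms are non-positive, and the objective attains its maximum of $0$ only for a perfect discriminator.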

A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:

$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

This means that the generator's loss function is itself learned: it is defined by the discriminator, which trains together with the generator in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as the data-loss term.

Model Implementation

We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.

TODO: Implement the Discriminator class in the hw3/gan.py module. If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.

In [16]:
import hw3.gan as gan

dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)

d0 = dsc(x0)
print(d0.shape)

test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU()
    (15): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (17): ReLU()
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU()
  )
  (classifier): Sequential(
    (0): Linear(in_features=8192, out_features=4, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2)
    (3): Linear(in_features=4, out_features=1, bias=True)
  )
)
torch.Size([1, 1])

TODO: Implement the Generator class in the hw3/gan.py module. If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.

In [17]:
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)

z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)

test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
  (seq): Sequential(
    (0): ConvTranspose2d(128, 1024, kernel_size=(4, 4), stride=(2, 2), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
torch.Size([1, 3, 64, 64])

Loss Implementation

Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\,\mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

We're using the discriminator twice in this expression: once to classify samples from the real data distribution, and once to classify generated samples. Our loss should therefore be computed from these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two binary cross-entropy losses.

GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, helping prevent the discriminator from overfitting.

We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.
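One possible way to combine the two cross-entropy terms with fuzzy labels is sketched below. Note that the exact noise-range convention is an assumption (so this sketch will not necessarily reproduce the test value in the next cell), and the function name is illustrative:

```python
import torch
import torch.nn.functional as F

def discriminator_loss_sketch(y_data, y_generated, data_label=1., label_noise=0.3):
    # Fuzzy targets: uniform noise of total width `label_noise` centered on
    # each label (this range convention is an assumption).
    half = label_noise / 2
    data_targets = torch.empty_like(y_data).uniform_(
        data_label - half, data_label + half)
    gen_label = 1. - data_label
    gen_targets = torch.empty_like(y_generated).uniform_(
        gen_label - half, gen_label + half)
    # Two binary cross-entropy terms: one for real samples, one for generated.
    loss_data = F.binary_cross_entropy_with_logits(y_data, data_targets)
    loss_generated = F.binary_cross_entropy_with_logits(y_generated, gen_targets)
    return loss_data + loss_generated
```

Using the `with_logits` variant lets the discriminator output raw scores and keeps the loss numerically stable.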

TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.

In [18]:
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)

y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10

loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)

test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)

Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$

which can also be seen as a cross-entropy term.
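A minimal sketch of this single cross-entropy term (the function name is illustrative, and the exact conventions may differ from the required implementation):

```python
import torch
import torch.nn.functional as F

def generator_loss_sketch(y_generated, data_label=1.):
    # The generator wants the discriminator to assign the *real* label to its
    # samples, so the targets are simply data_label for every generated score.
    targets = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, targets)
```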

TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.

In [19]:
from hw3.gan import generator_loss_fn
torch.manual_seed(42)

y_generated = torch.rand(20) * 10

loss = generator_loss_fn(y_generated, data_label=1)
print(loss)

test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-5)
tensor(0.0223)

Sampling

Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.

There is an important nuance, however: sampling is also required while training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to have gradients.
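The gradient-control part can be sketched as follows; the toy linear "generator" here is purely hypothetical, standing in for the real Generator just to show how `grad_fn` behaves:

```python
import torch
from torch import nn

def sample_sketch(generator, n, z_dim, with_grad=False):
    # Draw n latent vectors from N(0, I) and map them through the generator.
    # When gradients aren't needed (e.g. for display), torch.no_grad() skips
    # building the autograd graph, saving memory and computation.
    ctx = torch.enable_grad() if with_grad else torch.no_grad()
    with ctx:
        z = torch.randn(n, z_dim)
        return generator(z)

# Hypothetical stand-in generator, only to demonstrate the behavior:
toy_gen = nn.Linear(8, 4)
no_grad_samples = sample_sketch(toy_gen, 5, 8, with_grad=False)
grad_samples = sample_sketch(toy_gen, 5, 8, with_grad=True)
```

Samples produced under `torch.no_grad()` have no `grad_fn`, while those produced with gradients enabled are attached to the autograd graph.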

TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.

In [20]:
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())

samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)

Training

Training GANs is a bit different, since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.

As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)
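The alternating per-batch update can be sketched as below. This is only one possible structure, not the required implementation; `ToyGen` is a hypothetical stand-in for the real Generator (with the same `sample()` interface) so that the sketch is self-contained:

```python
import torch
import torch.nn.functional as F
from torch import nn

class ToyGen(nn.Module):
    # Hypothetical stand-in for the real Generator, exposing sample().
    def __init__(self, z_dim=8, out_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.net = nn.Linear(z_dim, out_dim)

    def forward(self, z):
        return self.net(z)

    def sample(self, n, with_grad=False):
        ctx = torch.enable_grad() if with_grad else torch.no_grad()
        with ctx:
            return self(torch.randn(n, self.z_dim))

def train_batch_sketch(dsc, gen, dsc_loss_fn, gen_loss_fn,
                       dsc_optimizer, gen_optimizer, x_data):
    # 1. Discriminator step: fake samples are drawn WITHOUT gradients, so
    #    this update cannot flow back into the generator.
    dsc_optimizer.zero_grad()
    fake = gen.sample(x_data.shape[0], with_grad=False)
    dsc_loss = dsc_loss_fn(dsc(x_data), dsc(fake))
    dsc_loss.backward()
    dsc_optimizer.step()

    # 2. Generator step: fresh samples WITH gradients, so the loss can
    #    backpropagate through the discriminator into the generator.
    #    Only the generator's optimizer applies the update.
    gen_optimizer.zero_grad()
    gen_loss = gen_loss_fn(dsc(gen.sample(x_data.shape[0], with_grad=True)))
    gen_loss.backward()
    gen_optimizer.step()

    return dsc_loss.item(), gen_loss.item()
```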

TODO:

  1. Implement the train_batch function in the hw3/gan.py module.
  2. Tweak the hyperparameters in the part3_gan_hyperparams() function within the hw3/answers.py module.
In [43]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape

# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)

# Optimizer
def create_optimizer(model_params, opt_params):
    opt_params = opt_params.copy()
    optimizer_type = opt_params['type']
    opt_params.pop('type')
    return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])

# Loss
def dsc_loss_fn(y_data, y_generated):
    return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])

def gen_loss_fn(y_generated):
    return gan.generator_loss_fn(y_generated, hp['data_label'])

# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show hypers
print(hp)
{'batch_size': 32, 'z_dim': 128, 'data_label': 1, 'label_noise': 0.3, 'discriminator_optimizer': {'type': 'Adam', 'weight_decay': 0.001, 'betas': (0.5, 0.999), 'lr': 0.0001}, 'generator_optimizer': {'type': 'Adam', 'weight_decay': 0.001, 'betas': (0.5, 0.999), 'lr': 0.0001}}
In [54]:
import IPython.display
import tqdm
from hw3.gan import train_batch

num_epochs = 100

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
    checkpoint_file = checkpoint_file_final

for epoch_idx in range(num_epochs):
    # We'll accumulate batch losses and show an average once per epoch.
    dsc_losses = []
    gen_losses = []
    print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
    
    with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
        for batch_idx, (x_data, _) in enumerate(dl_train):
            x_data = x_data.to(device)
            dsc_loss, gen_loss = train_batch(
                dsc, gen,
                dsc_loss_fn, gen_loss_fn,
                dsc_optimizer, gen_optimizer,
                x_data)
            dsc_losses.append(dsc_loss)
            gen_losses.append(gen_loss)
            pbar.update()

    dsc_avg_loss, gen_avg_loss = np.mean(dsc_losses), np.mean(gen_losses)
    print(f'Discriminator loss: {dsc_avg_loss}')
    print(f'Generator loss:     {gen_avg_loss}')
        
    samples = gen.sample(5, with_grad=False)
    fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
    IPython.display.display(fig)
    plt.close(fig)
--- EPOCH 1/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.66it/s]
Discriminator loss: 0.2104850661228685
Generator loss:     5.232235277400298
--- EPOCH 2/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.98it/s]
Discriminator loss: 0.23571953703375423
Generator loss:     5.371320724487305
--- EPOCH 3/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.73it/s]
Discriminator loss: 0.22443616872324662
Generator loss:     6.22440432099735
--- EPOCH 4/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.70it/s]
Discriminator loss: 0.16487004333997474
Generator loss:     5.709871909197639
--- EPOCH 5/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.55it/s]
Discriminator loss: 0.19601186014273586
Generator loss:     5.630541941698859
--- EPOCH 6/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.01it/s]
Discriminator loss: 0.19025547002606533
Generator loss:     5.7892750852248245
--- EPOCH 7/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.20it/s]
Discriminator loss: 0.1917685342623907
Generator loss:     5.775272818172679
--- EPOCH 8/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.02it/s]
Discriminator loss: 0.15585733489955172
Generator loss:     6.1006456262925095
--- EPOCH 9/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.14it/s]
Discriminator loss: 0.19561865969615824
Generator loss:     5.675936656839707
--- EPOCH 10/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.13it/s]
Discriminator loss: 0.20650859527728138
Generator loss:     6.099901732276468
--- EPOCH 11/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.82it/s]
Discriminator loss: 0.17880899196161942
Generator loss:     6.0154531703275795
--- EPOCH 12/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.98it/s]
Discriminator loss: 0.1536852005211746
Generator loss:     6.2796140278086945
--- EPOCH 13/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.25it/s]
Discriminator loss: 0.14641955319572897
Generator loss:     6.12741086062263
--- EPOCH 14/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.93it/s]
Discriminator loss: 0.18422164307797656
Generator loss:     6.318891553317799
--- EPOCH 15/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.88it/s]
Discriminator loss: 0.1644474467372193
Generator loss:     6.447155559764189
--- EPOCH 16/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.21it/s]
Discriminator loss: 0.1284379601916846
Generator loss:     6.709622635560877
--- EPOCH 17/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.05it/s]
Discriminator loss: 0.13761782514698365
Generator loss:     6.5783331674688
--- EPOCH 18/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.95it/s]
Discriminator loss: 0.2822633262942819
Generator loss:     6.204679741578944
--- EPOCH 19/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.79it/s]
Discriminator loss: 0.13365542921511567
Generator loss:     6.124513990738812
--- EPOCH 20/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.12it/s]
Discriminator loss: 0.17097877612447038
Generator loss:     6.5844942822175865
--- EPOCH 21/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.04it/s]
Discriminator loss: 0.2048212929683573
Generator loss:     6.579410749323228
--- EPOCH 22/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.96it/s]
Discriminator loss: 0.17134850368122845
Generator loss:     6.317334932439468
--- EPOCH 23/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.11it/s]
Discriminator loss: 0.06852861712960635
Generator loss:     7.19891180711634
--- EPOCH 24/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.22it/s]
Discriminator loss: 0.14982541343745062
Generator loss:     6.64442850561703
--- EPOCH 25/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.07it/s]
Discriminator loss: 0.1717278466505163
Generator loss:     7.029782337300918
--- EPOCH 26/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.17it/s]
Discriminator loss: 0.17784583995885708
Generator loss:     7.180853899787454
--- EPOCH 27/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.89it/s]
Discriminator loss: 0.15919926105176702
Generator loss:     6.810590519624598
--- EPOCH 28/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.85it/s]
Discriminator loss: 0.12100651202832952
Generator loss:     7.038385082693661
--- EPOCH 29/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.09it/s]
Discriminator loss: 0.1525040772907874
Generator loss:     6.687205539030187
--- EPOCH 30/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.85it/s]
Discriminator loss: 0.17617460663485177
Generator loss:     6.836430689867805
--- EPOCH 31/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.15it/s]
Discriminator loss: 0.0880723926512634
Generator loss:     7.559782729429357
--- EPOCH 32/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.92it/s]
Discriminator loss: 0.17680789354969473
Generator loss:     6.578231755424948
--- EPOCH 33/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.81it/s]
Discriminator loss: 0.1517919345813639
Generator loss:     6.868827875922708
--- EPOCH 34/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.79it/s]
Discriminator loss: 0.19488725307233193
Generator loss:     6.9007936645956605
--- EPOCH 35/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.13it/s]
Discriminator loss: 0.16586365813718124
Generator loss:     7.40082207848044
--- EPOCH 36/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.74it/s]
Discriminator loss: 0.14876083295573206
Generator loss:     6.559637013603659
--- EPOCH 37/100 ---
100%|██████████| 17/17 [00:03<00:00,  4.96it/s]
Discriminator loss: 0.11414158760624774
Generator loss:     7.437652812284582
--- EPOCH 38/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.91it/s]
Discriminator loss: 0.21772026840378256
Generator loss:     6.340291331796085
--- EPOCH 39/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.84it/s]
Discriminator loss: 0.18066350195337744
Generator loss:     6.721657360301299
--- EPOCH 40/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.90it/s]
Discriminator loss: 0.1617299684268587
Generator loss:     6.625875697416418
--- EPOCH 41/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.95it/s]
Discriminator loss: 0.19794127739527645
Generator loss:     6.804766739115996
--- EPOCH 42/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.15it/s]
Discriminator loss: 0.17773300212095766
Generator loss:     7.329222679138184
--- EPOCH 43/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.09it/s]
Discriminator loss: 0.22234640796394908
Generator loss:     6.2571783346288345
--- EPOCH 44/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: 0.18086633929873214
Generator loss:     6.378472636727726
--- EPOCH 45/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.09it/s]
Discriminator loss: 0.21271933461813367
Generator loss:     6.16764357510735
--- EPOCH 46/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.07it/s]
Discriminator loss: 0.1407627404174384
Generator loss:     6.7020626909592576
--- EPOCH 47/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.07it/s]
Discriminator loss: 0.11332968490965226
Generator loss:     6.44600181018605
--- EPOCH 48/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.18it/s]
Discriminator loss: 0.16300714750062018
Generator loss:     6.320674924289479
--- EPOCH 49/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.99it/s]
Discriminator loss: 0.1813364186707665
Generator loss:     7.128478386822869
--- EPOCH 50/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.85it/s]
Discriminator loss: 0.23097476582316792
Generator loss:     6.670736172619988
--- EPOCH 51/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.02it/s]
Discriminator loss: 0.09786212477175628
Generator loss:     6.791717921986299
--- EPOCH 52/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.14it/s]
Discriminator loss: 0.18331952568362742
Generator loss:     6.81852473932154
--- EPOCH 53/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.48it/s]
Discriminator loss: 0.2219028074075194
Generator loss:     6.671722145641551
--- EPOCH 54/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.44it/s]
Discriminator loss: 0.22358783422147527
Generator loss:     6.139595171984504
--- EPOCH 55/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.07it/s]
Discriminator loss: 0.14912109000279622
Generator loss:     6.118021291844985
--- EPOCH 56/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.36it/s]
Discriminator loss: 0.12866726430023417
Generator loss:     6.610690229079303
--- EPOCH 57/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.08it/s]
Discriminator loss: 0.16428225388860002
Generator loss:     6.939194258521585
--- EPOCH 58/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.06it/s]
Discriminator loss: 0.20538907454294317
Generator loss:     6.709279705496395
--- EPOCH 59/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.14it/s]
Discriminator loss: 0.09481469224042752
Generator loss:     7.229408684898825
--- EPOCH 60/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.08it/s]
Discriminator loss: 0.11453177516951281
Generator loss:     7.147596611696131
--- EPOCH 61/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.13it/s]
Discriminator loss: 0.1828874570920187
Generator loss:     6.993250678567326
--- EPOCH 62/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.19it/s]
Discriminator loss: 0.12596495055100498
Generator loss:     7.3126621807322785
--- EPOCH 63/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.43it/s]
Discriminator loss: 0.17658651707803502
Generator loss:     7.251228052027085
--- EPOCH 64/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.30it/s]
Discriminator loss: 0.10668047254576403
Generator loss:     7.212910539963666
--- EPOCH 65/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.26it/s]
Discriminator loss: 0.08577550016343594
Generator loss:     7.160778606639189
--- EPOCH 66/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.22it/s]
Discriminator loss: 0.12293543053023956
Generator loss:     7.144527182859533
--- EPOCH 67/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.49it/s]
Discriminator loss: 0.16996412789996931
Generator loss:     7.528690534479478
--- EPOCH 68/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.19it/s]
Discriminator loss: 0.07860983338426142
Generator loss:     7.814025037428912
--- EPOCH 69/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.29it/s]
Discriminator loss: 0.24334019137655988
Generator loss:     6.152402120478013
--- EPOCH 70/100 ---
100%|██████████| 17/17 [00:03<00:00,  4.94it/s]
Discriminator loss: 0.14791723077788071
Generator loss:     7.22919433257159
--- EPOCH 71/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.18it/s]
Discriminator loss: 0.31150157661998973
Generator loss:     6.812634103438434
--- EPOCH 72/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.06it/s]
Discriminator loss: 0.10162635542014066
Generator loss:     7.010197695563821
--- EPOCH 73/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.14it/s]
Discriminator loss: 0.1518172966864179
Generator loss:     7.004276696373434
--- EPOCH 74/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.16it/s]
Discriminator loss: 0.18425019282628508
Generator loss:     7.0809484650106995
--- EPOCH 75/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.14it/s]
Discriminator loss: 0.15221571615513632
Generator loss:     7.240997538847082
--- EPOCH 76/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.05it/s]
Discriminator loss: 0.15491124527419314
Generator loss:     6.776064620298498
--- EPOCH 77/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.28it/s]
Discriminator loss: 0.139634500750724
Generator loss:     7.009880262262681
--- EPOCH 78/100 ---
100%|██████████| 17/17 [00:03<00:00,  4.30it/s]
Discriminator loss: 0.1963846289059695
Generator loss:     7.163768768310547
--- EPOCH 79/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.06it/s]
Discriminator loss: 0.1572239850373829
Generator loss:     6.720290604759665
--- EPOCH 80/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.98it/s]
Discriminator loss: 0.12188968719804988
Generator loss:     7.252727592692656
--- EPOCH 81/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.30it/s]
Discriminator loss: 0.12670191300704198
Generator loss:     6.667434972875259
--- EPOCH 82/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.10it/s]
Discriminator loss: 0.11248969527728417
Generator loss:     7.198457044713638
--- EPOCH 83/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.20it/s]
Discriminator loss: 0.11614164206034996
Generator loss:     7.2386157933403465
--- EPOCH 84/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.06it/s]
Discriminator loss: 0.15769422865089247
Generator loss:     7.224797809825224
--- EPOCH 85/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.06it/s]
Discriminator loss: 0.11265083753010806
Generator loss:     8.062072473413805
--- EPOCH 86/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.05it/s]
Discriminator loss: 0.1909891370245639
Generator loss:     7.273121020373176
--- EPOCH 87/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.95it/s]
Discriminator loss: 0.0711965941111831
Generator loss:     7.398668092839858
--- EPOCH 88/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.08it/s]
Discriminator loss: 0.11818586498060647
Generator loss:     7.55328192430384
--- EPOCH 89/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.90it/s]
Discriminator loss: 0.12663840223103762
Generator loss:     8.232169067158418
--- EPOCH 90/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.72it/s]
Discriminator loss: 0.09800472110509872
Generator loss:     7.852265947005328
--- EPOCH 91/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.62it/s]
Discriminator loss: 0.19031105826006217
Generator loss:     7.519415602964513
--- EPOCH 92/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.87it/s]
Discriminator loss: 0.11785269178011838
Generator loss:     7.5413768151227165
--- EPOCH 93/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.01it/s]
Discriminator loss: 0.18558044841184335
Generator loss:     8.440010014702292
--- EPOCH 94/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.82it/s]
Discriminator loss: 0.20992926858803806
Generator loss:     8.222050610710593
--- EPOCH 95/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.95it/s]
Discriminator loss: 0.20124570784323356
Generator loss:     7.81480006610646
--- EPOCH 96/100 ---
100%|██████████| 17/17 [00:04<00:00,  3.83it/s]
Discriminator loss: 0.1217698459239567
Generator loss:     7.88374912037569
--- EPOCH 97/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.09it/s]
Discriminator loss: 0.18535523883560123
Generator loss:     7.8992780236636895
--- EPOCH 98/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.01it/s]
Discriminator loss: 0.053624547798843944
Generator loss:     7.6726537872763245
--- EPOCH 99/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.21it/s]
Discriminator loss: 0.12148237206480082
Generator loss:     7.971524743472829
--- EPOCH 100/100 ---
100%|██████████| 17/17 [00:04<00:00,  4.18it/s]
Discriminator loss: 0.17740684627171824
Generator loss:     7.747653961181641
In [53]:
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
    gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [46]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?

In [47]:
display_answer(hw3.answers.part3_q1)

Your answer:

We maintain gradients when sampling inside the batch-training function, during the generator's update step: the generator's loss is computed by feeding its samples through the discriminator, so gradients must flow back through the samples for the optimizer to update the generator's parameters. On all other occasions (e.g. when generating fake images for the discriminator's update, or when sampling for display) we discard the gradients, so that the generator is not updated through those samples and no computation graph is kept unnecessarily, which would otherwise interfere with training and waste memory.

Question 2

  1. When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?

  2. What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?

In [48]:
display_answer(hw3.answers.part3_q2)

Your answer:

  1. We can't decide to stop training based solely on the generator's loss being below a certain threshold, because the generator's loss and the discriminator's loss are coupled. The generator's loss may be very low at some point (which would suggest stopping), yet in the next batch the discriminator may suddenly improve and find new differences between the real and fake images, and the generator's loss will go back up.

  2. If the discriminator's loss remains constant while the generator's loss keeps decreasing, the discriminator can no longer tell the difference between real and fake images, while the generator keeps making its images better and better in comparison to the real ones.

Question 3

Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?

In [49]:
display_answer(hw3.answers.part3_q3)

Your answer: The GAN consists of two parts that compete against each other: the generator tries to construct a realistic image, while the discriminator tries to determine whether an image is real. This adversarial training pushes the generated images to look as similar as possible to the specific real images in the training set, so the GAN's results are better as overall images, for example in the background details. The VAE, on the other hand, aims to sample images from the latent space, i.e. to reconstruct images from a prior distribution. We assume the instances can be reconstructed from a latent space of smaller dimension, and the VAE learns the main characteristics of the images in the dataset. Thus, even though its pictures are blurry, each sample captures the main characteristics of a face.